model.py
Class DistributedInferenceBaseModel
_generate_output
直到生成max new tokens
outputs = self.model( input_ids, position_ids=position_ids, past_key_values=past_key_values, use_cache=True, enable_star_attn=True, ) # type: ignore
在最后一个rank上更新past_key_values
用logits获取新的token加入到output中
更新input和position
- 更新的数据是整个query再次传进去的
- query只添加到最后一个rank上,在query生成比较长的时候可能会出现负载均衡问题
class StarAttentionModel
- def _tokenize_and_partition_context
- 将输入padding为可整除并转化为tokens和positions两个tensor
- def _process_blockwise_context
- 每个rank做prefill,循环n次,每次计算一个block_size大小的,anchor block选择为rank0的第一个block
- 返回当前rank的kv cache
- def __call__
- 生成长文本的KV Cache
- 调用_tokenize_and_partition_context获取ctx_ids, position_ids
- 将ctx_ids拆分为world_size个tensor类型的ctx_ids_blocks,每个tensor的形状是[-1,1,block_size]
- position_ids同理转化为position_ids_blocks
- 调用_tokenize_and_partition_context生成当前rank的kv cache
- 生成Query
- embedding
- 调用_generate_output生成结果
- 生成长文本的KV Cache
- def _tokenize_and_partition_context